import torchvision.transforms as transforms
import torchvision
import os
from torch.utils.data import Dataset
import numpy as np
import torch
import random


class DVSCifar10(Dataset):
    """CIFAR10-DVS samples stored as one ``{index}.pt`` file per sample.

    Each file holds a ``(data, target)`` pair loadable with ``torch.load``.
    ``data`` is sliced along its last dimension; every slice is round-tripped
    through PIL (ToPILImage -> Resize(48, 48) -> ToTensor) and the results are
    stacked along a new leading dimension.

    Args:
        root: Directory containing the ``.pt`` sample files (``~`` is expanded).
        train: Stored as ``self.train`` but not otherwise used by this class.
        transform: Boolean flag (NOT a callable). When truthy, each sample gets
            a random horizontal flip (p=0.5) and a random circular shift of
            -5..5 pixels along both of the last two axes.
        target_transform: Optional callable applied to the target; any falsy
            value disables it.
    """

    def __init__(self, root, train=True, transform=False, target_transform=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.resize = transforms.Resize(size=(48, 48))
        self.tensorx = transforms.ToTensor()
        self.imgx = transforms.ToPILImage()

    def __getitem__(self, index):
        """
        Args:
            index (int): Index; the sample is read from ``root/{index}.pt``.
        Returns:
            tuple: (data, target) where ``data`` stacks the per-slice frames on
            dim 0 and ``target`` is a ``torch.long`` tensor with its trailing
            dimension squeezed.
        """
        # NOTE(security): torch.load unpickles the file, so only point ``root``
        # at trusted data.
        data, target = torch.load(os.path.join(self.root, '{}.pt'.format(index)))
        # Assumes data is (C, H, W, T) with frames on the last axis -- TODO confirm.
        # NOTE(review): ToPILImage scales float input by 255 before the uint8
        # cast, so float frames presumably lie in [0, 1] -- confirm.
        frames = [
            self.tensorx(self.resize(self.imgx(data[..., t])))
            for t in range(data.size(-1))
        ]
        data = torch.stack(frames, dim=0)
        if self.transform:
            if random.random() > 0.5:
                data = torch.flip(data, dims=(3,))
            off1 = random.randint(-5, 5)
            off2 = random.randint(-5, 5)
            # Circular shift: pixels pushed off one edge re-enter on the other.
            data = torch.roll(data, shifts=(off1, off2), dims=(2, 3))

        if self.target_transform:
            target = self.target_transform(target)
        # as_tensor lets target_transform return a plain int as well as a tensor.
        return data, torch.as_tensor(target).long().squeeze(-1)

    def __len__(self):
        """Return the number of ``.pt`` files in ``root`` (re-listed per call).

        Only ``.pt`` files are counted so that stray files in the directory do
        not inflate the length. Indexing assumes the files are named
        ``0.pt`` .. ``{len-1}.pt`` contiguously.
        """
        return sum(1 for name in os.listdir(self.root) if name.endswith('.pt'))


def build_dvscifar(path):
    """Create the CIFAR10-DVS train/validation datasets rooted at ``path``.

    Reads samples from ``path/train`` and ``path/test`` (see ``DVSCifar10``).
    Only the training split gets the random flip/shift augmentation.

    Args:
        path: Dataset root directory.

    Returns:
        ``(train_dataset, val_dataset)`` tuple of ``DVSCifar10`` instances.
    """
    train_dataset = DVSCifar10(root=os.path.join(path, 'train'), train=True, transform=True)
    # The validation split is explicitly marked train=False.
    val_dataset = DVSCifar10(root=os.path.join(path, 'test'), train=False)

    return train_dataset, val_dataset



def _build_cifar(dataset_cls, data_path, mean, std):
    """Return ``(train, test)`` datasets for a torchvision CIFAR dataset class.

    The training split uses RandomCrop(32, padding=4) and RandomHorizontalFlip.
    Both splits are converted with ToTensor and normalized with the given
    per-channel ``mean``/``std``. Data is downloaded into ``data_path`` if missing.
    """
    normalize = transforms.Normalize(mean, std)
    trans_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    trans_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = dataset_cls(root=data_path, train=True, download=True,
                                transform=trans_train)
    test_dataset = dataset_cls(root=data_path, train=False, download=True,
                               transform=trans_test)
    return train_dataset, test_dataset


def _build_imagenet(root):
    """Return ``(train, val)`` ImageFolder datasets from ``root/train`` and ``root/val``."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(root, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    )
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(root, 'val'),
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
    )
    return train_dataset, test_dataset


def dataload(data_path, data_name, *, imagenet_root='/data/linh/ImageNet'):
    """Build the ``(train_dataset, test_dataset)`` pair for ``data_name``.

    Mean/std normalization is applied here for CIFAR10, CIFAR100 and IMAGE.
    MNIST only gets ToTensor, and AUGDVS applies no Normalize.

    Args:
        data_path: Root directory for MNIST/CIFAR10/CIFAR100 (downloaded if
            missing) and for AUGDVS (must contain ``train``/``test``).
        data_name: One of 'MNIST', 'CIFAR10', 'CIFAR100', 'IMAGE', 'AUGDVS'.
        imagenet_root: Root containing ``train``/``val`` ImageFolder
            directories, used only for 'IMAGE'. The default is the previously
            hard-coded location.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.

    Raises:
        ValueError: If ``data_name`` is not recognised.
    """
    if data_name == 'MNIST':
        trans = transforms.Compose([transforms.ToTensor()])
        train_dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=True,
                                                   transform=trans)
        test_dataset = torchvision.datasets.MNIST(root=data_path, train=False, download=True,
                                                  transform=trans)
    elif data_name == 'CIFAR10':
        # NOTE(review): this std triple is the commonly copied one and differs
        # from per-channel training-set std values reported elsewhere
        # (~0.247/0.243/0.261). It is kept unchanged so existing checkpoints see
        # identical inputs -- confirm before changing.
        train_dataset, test_dataset = _build_cifar(
            torchvision.datasets.CIFAR10, data_path,
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    elif data_name == 'CIFAR100':
        train_dataset, test_dataset = _build_cifar(
            torchvision.datasets.CIFAR100, data_path,
            (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    elif data_name == 'IMAGE':
        # ImageNet is read from ``imagenet_root``, not ``data_path``.
        train_dataset, test_dataset = _build_imagenet(imagenet_root)
    elif data_name == 'AUGDVS':
        train_dataset, test_dataset = build_dvscifar(data_path)
    else:
        raise ValueError("Can't find current data_name in `dataload` method.")

    return train_dataset, test_dataset


